{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "ff3392cf-f1c6-4067-9213-0eecca4e70f6",
   "metadata": {},
   "source": [
    "# Use LLM for metadata extraction from a query\n",
    "\n",
    "This demo uses a large language model for metadata extraction from a query.\n",
    "\n",
    "Uses OpenAI model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "13029c01-a035-4aae-9291-9d4f163252ce",
   "metadata": {},
   "outputs": [],
   "source": [
    "!pip3 install openai==1.39.0 pydantic==2.8.2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "59aa655f-6015-4620-88b7-c38658e21c12",
   "metadata": {},
   "outputs": [],
   "source": [
    "from ipython_secrets import get_secret\n",
    "import os\n",
    "\n",
    "secret = get_secret('OPENAI_API_KEY')\n",
    "os.environ[\"OPENAI_API_KEY\"] = secret\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "8455e9be-fc4b-49cc-ae19-37a5c66e89a9",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Query: what's a recipe for vegetarian spaghetti?\n",
      "Topic: recipe\n",
      "---\n",
      "Query: what is the best way to poach an egg?\n",
      "Topic: cooking_technique\n",
      "---\n",
      "Query: What blender setting should I use to make a fruit smoothie?\n",
      "Topic: equipment\n",
      "---\n",
      "Query: Can you give me a recipe for chocolate chip cookies?\n",
      "Topic: recipe\n",
      "---\n",
      "Query: Why is the sky blue?\n",
      "Topic: None\n",
      "---\n"
     ]
    }
   ],
   "source": [
    "from openai import OpenAI\n",
    "from pydantic import BaseModel\n",
    "import json\n",
    "from typing import Literal, Optional\n",
    "\n",
    "# Create client to call model\n",
    "api_key = os.environ[\"OPENAI_API_KEY\"]\n",
    "client = OpenAI(\n",
    "    api_key=api_key,\n",
    ")\n",
    "\n",
    "# Create classifier\n",
    "class ContentTopic(BaseModel):\n",
    "    topic: Optional[Literal[\n",
    "        \"nutritional_information\",\n",
    "        \"equipment\",\n",
    "        \"cooking_technique\",\n",
    "        \"recipe\"\n",
    "    ]]\n",
    "\n",
    "function_definition = {\n",
    "    \"name\": \"classify_topic\",\n",
    "    \"description\": \"Extract the key topics from the query\",\n",
    "    \"parameters\": json.loads(ContentTopic.schema_json())\n",
    "}\n",
    "# The topic classifier uses few-shot examples to optimize the classification task.\n",
    "def get_topic(query: str):\n",
    "    response = client.chat.completions.create(\n",
    "    model=\"gpt-4o-mini\", \n",
    "    functions=[function_definition], \n",
    "    function_call={ \"name\": function_definition[\"name\"] },\n",
    "    temperature=0,\n",
    "    messages=[\n",
    "        {\n",
    "            \"role\": \"system\",\n",
    "            \"content\": \"\"\"Extract the topic of the following user query about cooking.\n",
    "Only use the topics present in the content topic classifier function.\n",
    "If you cannot tell the query topic or it is not about cooking, respond `null`. Output JSON.\n",
    "You MUST choose one of the given content topic types.\n",
    "Example 1:\n",
    "User:  \"How many grams of sugar are in a banana?\"\n",
    "Assistant: '{\"topic\": \"nutritional_information\"}'\n",
    "Example 2:\n",
    "User: \"What are the ingredients for a classic margarita?\"\n",
    "Assistant: '{\"topic\": \"recipe\"}'\n",
    "Example 3:\n",
    "User: \"What kind of knife is best for chopping vegetables?\"\n",
    "Assistant: '{\"topic\": \"equipment\"}'\n",
    "Example 4:\n",
    "User: \"What is a quick recipe for chicken stir-fry?\"\n",
    "Assistant: '{\"topic\": \"recipe\"}'\n",
    "Example 5:\n",
    "User: Who is the best soccer player ever?\n",
    "Assistant: '{\"topic\": null}'\n",
    "Example 6:\n",
    "User: Explain gravity to me\n",
    "Assistant: '{\"topic\": null}'\"\"\",\n",
    "          },\n",
    "          {\n",
    "              \"role\": \"user\",\n",
    "              \"content\": query\n",
    "          }\n",
    "      ],\n",
    "    )\n",
    "    content = ContentTopic.model_validate(json.loads(response.choices[0].message.function_call.arguments)) \n",
    "    return content.topic\n",
    "\n",
    "## Test the classifier\n",
    "queries = [\n",
    "    \"what's a recipe for vegetarian spaghetti?\",\n",
    "    \"what is the best way to poach an egg?\",\n",
    "    \"What blender setting should I use to make a fruit smoothie?\",\n",
    "    \"Can you give me a recipe for chocolate chip cookies?\",\n",
    "    \"Why is the sky blue?\"\n",
    "]\n",
    "for query in queries:\n",
    "    print(f\"Query: {query}\")\n",
    "    print(f\"Topic: {get_topic(query)}\")\n",
    "    print(\"---\")\n",
    "\n",
    "    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0b402fbc-a427-442b-9491-194035c4216a",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
